/*
Dufwenberg & Dufwenberg's Perceived Cheating Aversion model 
Modified to accept data from Stata variables instead of locals
*/

mata
mata clear
void mse_function(todo, theta, y, g, negH) {
    // Objective evaluator passed to optimize() (see optimize_init_evaluator
    // below in this file).
    //
    // Arguments:
    //   todo  - supplied by optimize(); not used here
    //   theta - current parameter value
    //   y     - OUTPUT: scaled sum of squared differences between the Stata
    //           variable "reports" and the model prediction at theta;
    //           set to missing (.) when it cannot be computed
    //   g     - gradient slot; not used or filled here
    //   negH  - Hessian slot; not used or filled here
    //
    // Side effect: overwrites the Stata variable "dd_hat" with the predicted
    // distribution on EVERY call.
    //
    // The arithmetic below requires "reports" and "dd_hat" to have exactly
    // 6 observations (one per outcome); otherwise Mata aborts with a
    // conformability error.

    // Get observed data from Stata variable
    actual = st_data(., "reports")
    
    // Calculate one_eta values.
    // A zero denominator (e.g. theta = 0) yields missing, which then
    // propagates through everything computed below.
    one_eta = J(6, 1, .)
    one_eta[1] = 2/theta 
    for (i = 2; i <= 6; i++) {
        one_eta[i] = 2/((i-1)*(theta-2) + theta)
    }
    
    // Constants
    H_chance = 1/6
    lu_chance = (0, 1, 2, 3, 4, 5)' / 6
    
    // Calculate lying up fractions (from highest to lowest):
    // lu_frac[i] = one_eta[i] * prod_{j > i} (1 - one_eta[j])
    lu_frac = J(6, 1, .)
    lu_frac[6] = one_eta[6]
    for (i = 5; i >= 1; i--) {
        product = 1
        for (j = i+1; j <= 6; j++) {
            product = product * (1 - one_eta[j])
        }
        lu_frac[i] = one_eta[i] * product
    }
    
    // Calculate honest fractions:
    // H_frac[i] = 1 - (sum of lu_frac over all outcomes above i)
    H_frac = J(6, 1, .)
    H_frac[6] = 1
    for (i = 5; i >= 1; i--) {
        H_frac[i] = 1 - sum(lu_frac[i+1::6])
    }
    
    // Predicted distribution
    sx_hat = H_chance * H_frac + lu_frac :* lu_chance
    
    // Store predictions back to Stata
    st_store(., "dd_hat", sx_hat)
    
    // Mata's sum() treats missing elements as 0 by default. Without this
    // guard, missing data or an invalid theta would silently drop terms and
    // understate the objective; if every prediction were missing the
    // objective would be 0, which a minimizer would take for a perfect fit.
    // Returning missing tells optimize() the function cannot be evaluated here.
    if (hasmissing(actual) | hasmissing(sx_hat)) {
        y = .
        return
    }
    
    // Objective: squared errors in percentage points, summed and scaled by 6.
    // NOTE(review): this is a scaled SUM of squares, not a mean (a mean would
    // divide by 6). The scaling does not change the minimizing theta, but it
    // does change the reported magnitude - confirm the intended scale.
    y = sum(((actual - sx_hat)*100):^2) * (6)
}

// Optimization: minimize mse_function over theta by Newton-Raphson,
// starting from theta = 3.
S1 = optimize_init()
optimize_init_technique(S1, "nr")
optimize_init_which(S1, "min")
optimize_init_evaluator(S1, &mse_function())
optimize_init_params(S1, 3)
theta_hat = optimize(S1)
// NOTE(review): the objective is already in squared percentage points and
// scaled by 6 inside mse_function(); the extra factor of 100 here is kept
// as in the original - confirm the intended scale of the reported "MSE".
mse_val = optimize_result_value(S1) * 100

// mse_function() overwrites the Stata variable dd_hat on every call, and
// optimize() also calls it at perturbed parameter values when computing
// numerical derivatives. The last call made by optimize() is therefore not
// guaranteed to be at theta_hat. Evaluate once more at the estimate so that
// dd_hat holds the predictions that correspond to theta_hat.
fval_at_hat = .
g_unused = .
H_unused = .
mse_function(0, theta_hat, fval_at_hat, g_unused, H_unused)

// Store results in Stata
st_numscalar("theta_hat", theta_hat)
st_numscalar("mse", mse_val)

// Also store theta in the existing Stata variable theta_hat, repeated on
// every observation. The variable shares its name with the scalar above.
report = st_data(., "reports")
thetavec = J(rows(report), 1, theta_hat)
st_store(., "theta_hat", thetavec)


printf("Estimate of Theta is %f\n", theta_hat)
printf("Corresponding MSE is %f\n", mse_val)
printf("RMSE is %f\n", sqrt(mse_val))

// Verify predictions sum to 1
dd_hat = st_data(., "dd_hat")
sum_check = sum(dd_hat)
printf("Sum of predicted probabilities: %f\n", sum_check)
end

// Display results
*display "Estimate of Theta is " theta_hat
*display "Corresponding MSE is " mse
*display "RMSE is " sqrt(mse)

// Show comparison of actual vs predicted
*display "Actual vs Predicted:"
*list reports dd_hat, sep(0)
